{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f55d90c3-5a96-408c-a672-87df10462e60",
   "metadata": {},
   "source": [
    "# Private information retrieval\n",
    "\n",
    "This notebook explains how to do PIR with Concrete, in a simple way, with applications to blocking spam phone numbers or bad URLs. The principle of PIR is that there is a non-encrypted large database on the server side, which we can't move to the client side for some reasons, like it's too big or we don't want to for privacy reasons or it's updated too often. With PIR, we let the user query the database, and the query (input and output) is not seen in the clear by the server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "edb989d3-93b1-40cc-8ae4-5e739583c9bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing some libraries\n",
    "\n",
    "import argparse\n",
    "import random\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from concrete import fhe"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af0a4710-ed46-4d21-a992-21eb2ad36bd2",
   "metadata": {},
   "source": [
    "### One-hot vector\n",
    "\n",
    "Our database will be represented as a table T, and inputs to T will be `n_input_bits` bits. Cells of T are represented as `m_output_bits`.\n",
    "\n",
    "Before querying the database, the input `i` will be represented as [one-hot vectors](https://en.wikipedia.org/wiki/One-hot), ie a vector of `2**n_input_bits` bits: all of them are 0, but the one in the i-th position."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a1a50b5e-4c36-40e0-94b1-278fac66f322",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_one_hot_vector(index: int, size: int) -> np.ndarray:\n",
    "\n",
    "    one_hot_vector = np.zeros(shape=(size,), dtype=np.int8)\n",
    "    one_hot_vector[index] = 1\n",
    "    return one_hot_vector"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c265202-e5c4-4387-b67a-f9b8da1b2af8",
   "metadata": {},
   "source": [
    "Here are a few examples of one-hot vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "861fb83d-783c-42e3-a59e-6abfeb274ff2",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = make_one_hot_vector(0, size=5)\n",
    "assert np.array_equal(x, np.array([1, 0, 0, 0, 0]))\n",
    "\n",
    "x = make_one_hot_vector(2, size=7)\n",
    "assert np.array_equal(x, np.array([0, 0, 1, 0, 0, 0, 0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a3d6bae-dac6-41d3-94fe-153906b99aad",
   "metadata": {},
   "source": [
    "## Quering the database with a one-hot vector\n",
    "\n",
    "Quering the database will be very simple: the input will be given as a one-hot vector and is encrypted by the client. The database is in the clear on the server side. Just making a dot product with the database will return the right input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "87ea17f3-3c44-4f09-82d8-e5010f65bd8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ith_element_of_database(one_hot_vector: np.ndarray, database: np.ndarray) -> int:\n",
    "    return np.dot(one_hot_vector, database)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c67607e7-a93f-4755-8226-17035c2667bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "database = np.array([3, 5, 21, 7, 11, 13, 2, 17])\n",
    "assert database.ndim == 1\n",
    "database_length = database.shape[0]\n",
    "database_output_bits = np.ceil(np.log2(np.max(database))).astype(np.int32)\n",
    "\n",
    "assert database_output_bits == 5\n",
    "\n",
    "# For now, we have not compiled our functions so here, all the computations\n",
    "# in the following asserts are done in the clear, just to check the semantic\n",
    "# of the functions\n",
    "assert (\n",
    "    get_ith_element_of_database(make_one_hot_vector(0, size=database_length), database)\n",
    "    == database[0]\n",
    ")\n",
    "assert (\n",
    "    get_ith_element_of_database(make_one_hot_vector(3, size=database_length), database)\n",
    "    == database[3]\n",
    ")\n",
    "assert (\n",
    "    get_ith_element_of_database(make_one_hot_vector(4, size=database_length), database)\n",
    "    == database[4]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ef38c1a-eab5-419a-bea0-aaaf1c66a6aa",
   "metadata": {},
   "source": [
    "## Private information retrieval with FHE\n",
    "\n",
    "Now, let's make that in a private way, without the server seing the query in the clear. First we compile the function with Concrete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "91c766c2-d30a-4e67-bf12-25e4083c67c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Computation Graph\n",
      "--------------------------------------------------------------------------------\n",
      "%0 = one_hot_vector        # EncryptedTensor<uint1, shape=(8,)>        ∈ [0, 1]\n",
      "%1 = database              # ClearTensor<uint5, shape=(8,)>            ∈ [2, 21]\n",
      "%2 = dot(%0, %1)           # EncryptedScalar<uint5>                    ∈ [2, 21]\n",
      "return %2\n",
      "--------------------------------------------------------------------------------\n",
      "\n",
      "MLIR\n",
      "--------------------------------------------------------------------------------\n",
      "module {\n",
      "  func.func @main(%arg0: tensor<8x!FHE.eint<5>>, %arg1: tensor<8xi6>) -> !FHE.eint<5> {\n",
      "    %0 = \"FHELinalg.to_signed\"(%arg0) : (tensor<8x!FHE.eint<5>>) -> tensor<8x!FHE.esint<5>>\n",
      "    %1 = \"FHELinalg.dot_eint_int\"(%0, %arg1) : (tensor<8x!FHE.esint<5>>, tensor<8xi6>) -> !FHE.esint<5>\n",
      "    %2 = \"FHE.to_unsigned\"(%1) : (!FHE.esint<5>) -> !FHE.eint<5>\n",
      "    return %2 : !FHE.eint<5>\n",
      "  }\n",
      "}\n",
      "--------------------------------------------------------------------------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def compile_function(database, **kwargs):\n",
    "    assert database.ndim == 1\n",
    "    database_length = database.shape[0]\n",
    "    inputset_length = 100\n",
    "    inputset = [\n",
    "        (make_one_hot_vector(np.random.randint(database_length), database_length), database)\n",
    "        for _ in range(inputset_length)\n",
    "    ]\n",
    "    # Also add the extreme value, which would actually be sufficient alone, as it reaches the maximal\n",
    "    # values in get_ith_element_of_database\n",
    "    inputset.append((make_one_hot_vector(np.argmax(database), database_length), database))\n",
    "\n",
    "    compiler = fhe.Compiler(\n",
    "        get_ith_element_of_database, {\"one_hot_vector\": \"encrypted\", \"database\": \"clear\"}\n",
    "    )\n",
    "    circuit = compiler.compile(inputset, **kwargs)\n",
    "\n",
    "    return circuit\n",
    "\n",
    "\n",
    "circuit = compile_function(database, show_mlir=True, show_graph=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28d7e4d7-547b-4cfc-a34f-7c01e5368296",
   "metadata": {},
   "source": [
    "Then we can make inferences over encrypted input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5c686f23-5774-40c3-88bc-f90417b47f7e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FHE computation done in 2.6 milliseconds -- database is 8 (2**3) elements of 5 bits\n"
     ]
    }
   ],
   "source": [
    "def test_encrypted_queries(database, circuit, how_many_tests=1, verbose=True):\n",
    "\n",
    "    times = []\n",
    "\n",
    "    for _ in range(how_many_tests):\n",
    "        database_length = database.shape[0]\n",
    "        log_database_length = np.ceil(np.log2(database_length)).astype(np.int32)\n",
    "\n",
    "        # Random index in the database\n",
    "        random_index = np.random.randint(database_length)\n",
    "\n",
    "        # Turn it into one hot vector\n",
    "        x = make_one_hot_vector(random_index, database_length)\n",
    "\n",
    "        # Encrypt the query, on the client side\n",
    "        encrypted_x, _ = circuit.encrypt(x, None)\n",
    "\n",
    "        # Run the FHE computation on the server side\n",
    "        time_begin = time.time()\n",
    "        encrypted_y = circuit.run(encrypted_x, database)\n",
    "        time_end = time.time()\n",
    "        times.append(time_end - time_begin)\n",
    "\n",
    "        if verbose:\n",
    "            print(\n",
    "                f\"FHE computation done in {(time_end - time_begin) * 1000:.1f} milliseconds -- database is {database_length} (2**{log_database_length}) elements of {database_output_bits} bits\"\n",
    "            )\n",
    "\n",
    "        # Decrypt the result on the client side\n",
    "        y = circuit.decrypt(encrypted_y)\n",
    "\n",
    "        # And check the computations worked fine\n",
    "        assert y == database[random_index]\n",
    "\n",
    "    return times\n",
    "\n",
    "\n",
    "_ = test_encrypted_queries(database, circuit)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9716a15-0797-4270-b08a-7d1db97323e2",
   "metadata": {},
   "source": [
    "## Performances\n",
    "\n",
    "Now, obviously, we can do it for much larger databases. Let's see the performance!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3a573631-0bdf-4969-85c1-c5da9afcc2bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For a database of 2** 4 elements of  8 bits, average execution time is 1.2 milliseconds\n",
      "For a database of 2** 4 elements of 16 bits, average execution time is 1.3 milliseconds\n",
      "For a database of 2** 8 elements of  8 bits, average execution time is 4.5 milliseconds\n",
      "For a database of 2** 8 elements of 16 bits, average execution time is 5.0 milliseconds\n",
      "For a database of 2** 9 elements of  8 bits, average execution time is 6.8 milliseconds\n",
      "For a database of 2** 9 elements of 16 bits, average execution time is 13.4 milliseconds\n",
      "For a database of 2**10 elements of  4 bits, average execution time is 8.1 milliseconds\n",
      "For a database of 2**10 elements of  8 bits, average execution time is 13.3 milliseconds\n",
      "For a database of 2**12 elements of  4 bits, average execution time is 36.2 milliseconds\n",
      "For a database of 2**12 elements of  8 bits, average execution time is 55.1 milliseconds\n",
      "For a database of 2**14 elements of  4 bits, average execution time is 156.5 milliseconds\n",
      "For a database of 2**14 elements of  8 bits, average execution time is 221.3 milliseconds\n"
     ]
    }
   ],
   "source": [
    "how_many_tests = 10\n",
    "\n",
    "sample_list = [\n",
    "    (4, 8),\n",
    "    (4, 16),\n",
    "    (8, 8),\n",
    "    (8, 16),\n",
    "    (9, 8),\n",
    "    (9, 16),\n",
    "    (10, 4),\n",
    "    (10, 8),\n",
    "    (12, 4),\n",
    "    (12, 8),\n",
    "    (14, 4),\n",
    "    (14, 8),\n",
    "]\n",
    "timings_dic = {}\n",
    "\n",
    "for database_input_bits, database_output_bits in sample_list:\n",
    "\n",
    "    # Take a random database of expected size and output_bits\n",
    "    database_length = 2**database_input_bits\n",
    "    database = np.array(\n",
    "        [np.random.randint(2**database_output_bits) for _ in range(database_length)]\n",
    "    )\n",
    "\n",
    "    circuit = compile_function(database, show_mlir=False, show_graph=False)\n",
    "\n",
    "    # Benchmark\n",
    "    times = test_encrypted_queries(database, circuit, how_many_tests=how_many_tests, verbose=False)\n",
    "    mean_time = np.mean(times)\n",
    "    timings_dic[(database_input_bits, database_output_bits)] = mean_time\n",
    "    print(\n",
    "        f\"For a database of 2**{str(database_input_bits):>2s} elements of {str(database_output_bits):>2s} bits, average execution time is {1000 * mean_time:.1f} milliseconds\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9370a24c-2e28-4a6b-ba66-97aa548d7cd7",
   "metadata": {},
   "source": [
    "## Using several tables\n",
    "\n",
    "Finally, let's remark that, to store more information without paying the price of having large database_output_bits, we can apply several dot product in the function, and, on the client side, concatenate the results. Let's do an example with 4 subtables, to gain 2 extra bits on the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fe2d50a5-0ba4-4433-9573-cf83c98b7db0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For a database of 2**14 elements of 32 bits with 4 sub-databases, average execution time is 276.4 milliseconds\n"
     ]
    }
   ],
   "source": [
    "# We'll have 4 subdatabases of 8b each\n",
    "number_of_subdatabases = 4\n",
    "database_output_bits_subdatabases = 8\n",
    "\n",
    "database_output_bits = database_output_bits * number_of_subdatabases\n",
    "\n",
    "\n",
    "def get_ith_element_of_database(\n",
    "    one_hot_vector: np.ndarray,\n",
    "    database0: np.ndarray,\n",
    "    database1: np.ndarray,\n",
    "    database2: np.ndarray,\n",
    "    database3: np.ndarray,\n",
    ") -> (int, int, int, int):\n",
    "    return (\n",
    "        np.dot(one_hot_vector, database0),\n",
    "        np.dot(one_hot_vector, database1),\n",
    "        np.dot(one_hot_vector, database2),\n",
    "        np.dot(one_hot_vector, database3),\n",
    "    )\n",
    "\n",
    "\n",
    "# Take a random database of expected size and output_bits\n",
    "database_length = 2**database_input_bits\n",
    "database = np.array([np.random.randint(2**database_output_bits) for _ in range(database_length)])\n",
    "\n",
    "database0 = (database >> 0) & 0xFF\n",
    "database1 = (database >> 8) & 0xFF\n",
    "database2 = (database >> 16) & 0xFF\n",
    "database3 = (database >> 24) & 0xFF\n",
    "\n",
    "\n",
    "def compile_function_split_database(database0, database1, database2, database3, **kwargs):\n",
    "    database_length = database0.shape[0]\n",
    "    inputset_length = 100\n",
    "    inputset = [\n",
    "        (\n",
    "            make_one_hot_vector(np.random.randint(database_length), database_length),\n",
    "            database0,\n",
    "            database1,\n",
    "            database2,\n",
    "            database3,\n",
    "        )\n",
    "        for _ in range(inputset_length)\n",
    "    ]\n",
    "    compiler = fhe.Compiler(\n",
    "        get_ith_element_of_database,\n",
    "        {\n",
    "            \"one_hot_vector\": \"encrypted\",\n",
    "            \"database0\": \"clear\",\n",
    "            \"database1\": \"clear\",\n",
    "            \"database2\": \"clear\",\n",
    "            \"database3\": \"clear\",\n",
    "        },\n",
    "    )\n",
    "    circuit = compiler.compile(inputset, **kwargs)\n",
    "    return circuit\n",
    "\n",
    "\n",
    "circuit = compile_function_split_database(\n",
    "    database0, database1, database2, database3, show_mlir=False, show_graph=False\n",
    ")\n",
    "\n",
    "\n",
    "def test_encrypted_queries_split_database(database, circuit, how_many_tests=1, verbose=True):\n",
    "\n",
    "    times = []\n",
    "\n",
    "    for _ in range(how_many_tests):\n",
    "        database_length = database.shape[0]\n",
    "        log_database_length = np.ceil(np.log2(database_length)).astype(np.int32)\n",
    "\n",
    "        # Random index in the database\n",
    "        random_index = np.random.randint(database_length)\n",
    "\n",
    "        # Turn it into one hot vector\n",
    "        x = make_one_hot_vector(random_index, database_length)\n",
    "\n",
    "        # Encrypt the query, on the client side\n",
    "        encrypted_x, _, _, _, _ = circuit.encrypt(x, None, None, None, None)\n",
    "\n",
    "        # Run the FHE computation on the server side\n",
    "        time_begin = time.time()\n",
    "        encrypted_y = circuit.run(encrypted_x, database0, database1, database2, database3)\n",
    "        time_end = time.time()\n",
    "        times.append(time_end - time_begin)\n",
    "\n",
    "        if verbose:\n",
    "            print(\n",
    "                f\"FHE computation done in {(time_end - time_begin) * 1000:.1f} milliseconds -- database is {database_length} (2**{log_database_length}) elements of {database_output_bits} bits\"\n",
    "            )\n",
    "\n",
    "        # Decrypt the result on the client side\n",
    "        y_bits = circuit.decrypt(encrypted_y)\n",
    "\n",
    "        # Relinearize the result\n",
    "        y = y_bits[0] + y_bits[1] * 256 + y_bits[2] * 256**2 + y_bits[3] * 256**3\n",
    "\n",
    "        # And check the computations worked fine\n",
    "        assert (\n",
    "            y == database[random_index]\n",
    "        ), f\"{y} {y:x} {y_bits} {y_bits[0]:x} {y_bits[1]:x} {y_bits[2]:x} {y_bits[3]:x} {database[random_index]:x}\"\n",
    "\n",
    "    return times\n",
    "\n",
    "\n",
    "# Benchmark\n",
    "times = test_encrypted_queries_split_database(\n",
    "    database, circuit, how_many_tests=how_many_tests, verbose=False\n",
    ")\n",
    "print(\n",
    "    f\"For a database of 2**{str(database_input_bits):>2s} elements of {str(database_output_bits):>2s} bits with 4 sub-databases, average execution time is {1000 * np.mean(times):.1f} milliseconds\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76855d6d-a65d-442c-b428-f1a077c07570",
   "metadata": {},
   "source": [
    "## Use-cases for phone spamming\n",
    "\n",
    "Now, let see where PIR could be used. Let's imagine we want to build a spam database. In France, there are 10 ** 9 ~ 2 ** 30 phone numbers, we could have a database T[i] for i an integer of 30 bits, returning a boolean stating if the phone number is a known spam number. The database would be server side, and often updated. Phones could query the database on an number, and if the result is positive, filter the call as a spam. All of this would be done without the server knowing the calling numbers.\n",
    "\n",
    "Then, the goal is to represent this table T as a database D:\n",
    "- with inputs of bitsize database_input_bits\n",
    "- with number_of_subdatabases subdatabases, all of which outputs database_output_bits_subdatabases bits of information\n",
    "such that bitsize database_input_bits + log2(database_output_bits * number_of_subdatabases) >= 30.\n",
    "\n",
    "Let's suppose the phone number to query is N, where N is represented as N0 || N1, where N0 is the first database_input_bits bits of N, and N1 are the remaining bits.\n",
    "\n",
    "The user would query D with N0. She would receive database_output_bits * number_of_subdatabases bits of information. By looking at the bit as the N1-th position, he would know if the phone number is a spam. \n",
    "\n",
    "For example, we can take\n",
    "- database_input_bits = 14\n",
    "- number_of_subdatabases = 8192\n",
    "- database_output_bits = 8\n",
    "\n",
    "Obviously, the goal is to find the combination fulfilling the condition such that the execution time is smaller. \n",
    "\n",
    "Another possibility to make this less computing intensive is to let some of the bits of the phone number in the clear, and just hide the remaining bits. Eg, we could have the first 3 digits in the clear, and just hide the last 6 digits, turning the condition to bitsize database_input_bits + log2(database_output_bits * number_of_subdatabases) >= 20.\n",
    "\n",
    "Below, we exhaust to find the best combination, for 30 and 20b phones.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9dbf947e-4703-432e-9195-5e0e38760f3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Estimated time would be  10070.0 seconds for  8388608 DBs of (4, 8)\n",
      "Estimated time would be   5356.0 seconds for  4194304 DBs of (4, 16)\n",
      "Estimated time would be   2362.0 seconds for   524288 DBs of (8, 8)\n",
      "Estimated time would be   1308.0 seconds for   262144 DBs of (8, 16)\n",
      "Estimated time would be   1775.0 seconds for   262144 DBs of (9, 8)\n",
      "Estimated time would be   1757.0 seconds for   131072 DBs of (9, 16)\n",
      "Estimated time would be   2135.0 seconds for   262144 DBs of (10, 4)\n",
      "Estimated time would be   1740.0 seconds for   131072 DBs of (10, 8)\n",
      "Estimated time would be   2372.0 seconds for    65536 DBs of (12, 4)\n",
      "Estimated time would be   1806.0 seconds for    32768 DBs of (12, 8)\n",
      "Estimated time would be   2564.0 seconds for    16384 DBs of (14, 4)\n",
      "Estimated time would be   1814.0 seconds for     8192 DBs of (14, 8)\n",
      "\n",
      "Best combination: 1308.0 seconds for a DB of 30 bits\n",
      "\n",
      "Estimated time would be     10.0 seconds for     8192 DBs of (4, 8)\n",
      "Estimated time would be      6.0 seconds for     4096 DBs of (4, 16)\n",
      "Estimated time would be      3.0 seconds for      512 DBs of (8, 8)\n",
      "Estimated time would be      2.0 seconds for      256 DBs of (8, 16)\n",
      "Estimated time would be      2.0 seconds for      256 DBs of (9, 8)\n",
      "Estimated time would be      2.0 seconds for      128 DBs of (9, 16)\n",
      "Estimated time would be      3.0 seconds for      256 DBs of (10, 4)\n",
      "Estimated time would be      2.0 seconds for      128 DBs of (10, 8)\n",
      "Estimated time would be      3.0 seconds for       64 DBs of (12, 4)\n",
      "Estimated time would be      2.0 seconds for       32 DBs of (12, 8)\n",
      "Estimated time would be      3.0 seconds for       16 DBs of (14, 4)\n",
      "Estimated time would be      2.0 seconds for        8 DBs of (14, 8)\n",
      "\n",
      "Best combination: 2.0 seconds for a DB of 20 bits\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Finding the best combination\n",
    "def find_best_combination(expected_total_bits):\n",
    "    best_combination = None\n",
    "\n",
    "    for database_input_bits, database_output_bits in timings_dic.keys():\n",
    "        remaining_bits = expected_total_bits - database_input_bits\n",
    "        assert remaining_bits > 0\n",
    "        number_of_subdatabases = np.ceil(2**remaining_bits / database_output_bits).astype(np.int32)\n",
    "        estimated_time = np.ceil(\n",
    "            number_of_subdatabases * timings_dic[(database_input_bits, database_output_bits)]\n",
    "        )\n",
    "\n",
    "        print(\n",
    "            f\"Estimated time would be {str(estimated_time):>8s} seconds for {str(number_of_subdatabases):>8s} DBs of {(database_input_bits, database_output_bits)}\"\n",
    "        )\n",
    "\n",
    "        if best_combination is None or estimated_time < best_combination[0]:\n",
    "            best_combination = (\n",
    "                estimated_time,\n",
    "                number_of_subdatabases,\n",
    "                database_input_bits,\n",
    "                database_output_bits,\n",
    "            )\n",
    "\n",
    "    print(\n",
    "        f\"\\nBest combination: {best_combination[0]} seconds for a DB of {expected_total_bits} bits\\n\"\n",
    "    )\n",
    "\n",
    "\n",
    "find_best_combination(30)\n",
    "find_best_combination(20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff578b6f-98b9-4073-8427-ec32aff5437a",
   "metadata": {},
   "source": [
    "## Another use-case for URL checking\n",
    "\n",
    "It might also be tempting to keep (and refresh very often) a list of bad URL on the server side, and to use them to protect user to click on bad links. Of course, there will be too many URLs to keep with the previous system: fortunately we have an hash-based solution for this. \n",
    "\n",
    "The principle will be to use a small non-cryptographic hash function, which maps strings to small integers, let say 20-bit integers. Then, any time the server would see a bad URL `u`, it would hash it to `h` and would store `T[h] = 1` to set that this hash is potentially dangerous. Then, with our system, the user can use privacy-preserving PIR to know if a given URL `u'` is dangerous, by having it to `h'` and checking if `T[h'] == 1`. \n",
    "\n",
    "As we know with such-a-small hash function, there will be collisions, which means that sometimes, the user will receive false positives: having `T[h] == 1` doesn't mean that this given URL is dangerous, but that there exists an URL with same hash which is dangerous. These collisions is not a problem per se, the user may just see a \"Warning, this URL is potentially dangerous\" but still access if he is confident. Or, we could use several different hash functions and different tables `T_i`, and we would check if all `T_i` return 1 to define if an URL is a spam, to highly reduce the probability of collisions. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
